# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable, Iterable, Sequence
import inspect
import types
from typing import Any

from .pytree import PyTreeRegistry as _PyTreeRegistry

class NoSharding:
  """One of the entry types allowed in `ShardingSpec.sharding`.

  Constructed with no arguments. Declares pickle support
  (`__getstate__`/`__setstate__`), equality and hashing. Judging only by the
  name, it presumably marks an array dimension that is not split — TODO
  confirm against the native implementation.
  """
  def __init__(self) -> None: ...
  # Pickle protocol: state is round-tripped as a tuple.
  def __getstate__(self) -> tuple: ...
  def __setstate__(self, arg: tuple, /) -> None: ...
  def __repr__(self) -> str: ...
  def __eq__(self, arg: object, /) -> bool: ...
  def __hash__(self) -> int: ...

class Chunked:
  """One of the entry types allowed in `ShardingSpec.sharding`.

  Constructed from a sequence of ints (positional-only) and exposes them via
  the read-only `chunks` property as a `list[int]`. Presumably the constructor
  argument and `chunks` hold the same values — TODO confirm.
  """
  def __init__(self, arg: Sequence[int], /) -> None: ...
  # Pickle protocol: state is round-tripped as a tuple.
  def __getstate__(self) -> tuple: ...
  def __setstate__(self, arg: tuple, /) -> None: ...
  @property
  def chunks(self) -> list[int]: ...
  def __repr__(self) -> str: ...
  # NOTE(review): __eq__ is declared without __hash__ here, unlike NoSharding
  # and ShardingSpec; confirm whether instances are hashable at runtime.
  def __eq__(self, arg: object, /) -> bool: ...

class Unstacked:
  """One of the entry types allowed in `ShardingSpec.sharding`.

  Constructed from a single int (positional-only) that is exposed through the
  read-only `size` property. Presumably `size` returns the constructor
  argument — TODO confirm.
  """
  def __init__(self, arg: int, /) -> None: ...
  # Pickle protocol: state is round-tripped as a tuple.
  def __getstate__(self) -> tuple: ...
  def __setstate__(self, arg: tuple, /) -> None: ...
  @property
  def size(self) -> int: ...
  def __repr__(self) -> str: ...
  # NOTE(review): __eq__ is declared without __hash__ here, unlike NoSharding
  # and ShardingSpec; confirm whether instances are hashable at runtime.
  def __eq__(self, arg: object, /) -> bool: ...

class ShardedAxis:
  """One of the entry types allowed in `ShardingSpec.mesh_mapping`.

  Constructed from a single int (positional-only) that is exposed through the
  read-only `axis` property. Presumably `axis` returns the constructor
  argument — TODO confirm.
  """
  def __init__(self, arg: int, /) -> None: ...
  # Pickle protocol: state is round-tripped as a tuple.
  def __getstate__(self) -> tuple: ...
  def __setstate__(self, arg: tuple, /) -> None: ...
  @property
  def axis(self) -> int: ...
  def __repr__(self) -> str: ...
  # NOTE(review): __eq__ is declared without __hash__ here, unlike NoSharding
  # and ShardingSpec; confirm whether instances are hashable at runtime.
  def __eq__(self, arg: object, /) -> bool: ...

class Replicated:
  """One of the entry types allowed in `ShardingSpec.mesh_mapping`.

  Constructed from a single int (positional-only) that is exposed through the
  read-only `replicas` property. Presumably `replicas` returns the
  constructor argument — TODO confirm.
  """
  def __init__(self, arg: int, /) -> None: ...
  # Pickle protocol: state is round-tripped as a tuple.
  def __getstate__(self) -> tuple: ...
  def __setstate__(self, arg: tuple, /) -> None: ...
  @property
  def replicas(self) -> int: ...
  def __repr__(self) -> str: ...
  # NOTE(review): __eq__ is declared without __hash__ here, unlike NoSharding
  # and ShardingSpec; confirm whether instances are hashable at runtime.
  def __eq__(self, arg: object, /) -> bool: ...

# NOTE(review): inheriting from Any makes type checkers treat attributes not
# declared below as Any; presumably intentional to tolerate extra attributes
# on the native object — confirm before narrowing.
class ShardingSpec(Any):
  """Pair of `sharding` and `mesh_mapping` entry sequences.

  The constructor accepts arbitrary iterables. The properties are declared to
  return tuples: `sharding` holds `NoSharding | Chunked | Unstacked` entries
  and `mesh_mapping` holds `ShardedAxis | Replicated` entries. Declares
  pickle support, equality and hashing.
  """
  def __init__(self, sharding: Iterable, mesh_mapping: Iterable) -> None: ...
  # Pickle protocol: state is round-tripped as a tuple.
  def __getstate__(self) -> tuple: ...
  def __setstate__(self, arg: tuple, /) -> None: ...
  @property
  def sharding(self) -> tuple[NoSharding | Chunked | Unstacked, ...]: ...
  @property
  def mesh_mapping(self) -> tuple[ShardedAxis | Replicated, ...]: ...
  def __eq__(self, arg: object, /) -> bool: ...
  def __hash__(self) -> int: ...

class PmapFunction:
  """Callable object returned by `pmap`.

  Besides being callable, it declares:
    * `__get__`, so it acts as a descriptor when stored on a class;
    * `__signature__`, for `inspect.signature`;
    * pickle support with a `dict` state;
    * private cache-introspection helpers (`_cache_miss`, `_cache_size`,
      `_cache_clear`, `_debug_cache_keys`).
  """
  def __call__(self, /, *args, **kwargs) -> Any:
    """Call self as a function with arbitrary positional/keyword arguments."""

  def __get__(self, instance, owner=..., /) -> Any:
    """Descriptor hook: return an attribute of instance, which is of type owner.

    Presumably binds `instance` so the object can be used as a method —
    TODO confirm against the native implementation.
    """
  # Exposed by CPython for types implementing the vectorcall protocol.
  __vectorcalloffset__: types.MemberDescriptorType = ...

  @property
  def __signature__(self) -> inspect.Signature: ...
  # Presumably the `cache_miss` callable passed to `pmap` — TODO confirm.
  @property
  def _cache_miss(self) -> Callable: ...
  # Pickle protocol: state is round-tripped as a dict.
  def __getstate__(self) -> dict: ...
  def __setstate__(self, arg: dict, /) -> None: ...
  # Cache introspection / maintenance; underscore-prefixed, so not public API.
  @property
  def _cache_size(self) -> int: ...
  def _cache_clear(self) -> None: ...
  # Returns a human-readable string; presumably a dump of cached keys for
  # debugging — TODO confirm format.
  def _debug_cache_keys(self) -> str: ...

def pmap(
    fun: Callable[..., Any],
    cache_miss: Callable[..., Any],
    static_argnums: Sequence[int],
    shard_arg_fallback: Callable[..., Any],
    pytree_registry: _PyTreeRegistry,
) -> PmapFunction:
  """Build a `PmapFunction` from the given callables and configuration.

  Args:
    fun: the user function; presumably the one `PmapFunction` dispatches to
      and whose signature `__signature__` reports — TODO confirm.
    cache_miss: callable presumably invoked when the dispatch cache has no
      entry (compare `PmapFunction._cache_miss`) — TODO confirm its expected
      signature and return value.
    static_argnums: positional-argument indices; presumably treated as static
      rather than passed through as data — TODO confirm.
    shard_arg_fallback: callable presumably used for arguments the fast path
      cannot shard itself — TODO confirm.
    pytree_registry: the `PyTreeRegistry` from the sibling `pytree` module;
      presumably used to flatten arguments.

  Returns:
    A new `PmapFunction`.
  """
  ...
